Add factorized llama model for testing. #604

Closed

Conversation

@rjpower (Collaborator) commented May 29, 2024

(Not intended for real review! But here's the current hacked-up status of trying to get a layerwise trainer going, in case it provides food for thought.) I can't for the life of me figure out how to get GitHub to acknowledge a file copy and show a diff: if there's a tip for how to do this with git, please let me know. Brain dump:

Random things:

  • I noticed that state.training_key isn't updated with the new_key when we take a step: https://github.com/stanford-crfm/levanter/blob/main/src/levanter/trainer.py#L495 . IIUC, we'll use the same randomness for every batch as a result (I could be horribly wrong here; a sketch of the pattern I'd expect is below). That said, it doesn't look like we use dropout or any other runtime randomness, so it probably doesn't change anything either way!
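
For reference, a minimal sketch of the key-threading pattern I'd expect (plain JAX with made-up state/field names, not Levanter's actual Trainer code): each step splits the current key, uses one half for that step's randomness, and writes the other half back into the state so the next batch sees fresh randomness.

import jax

def train_step(state, batch):
    # Split the carried key: one subkey for this step's stochastic ops,
    # one replacement key that is stored back for the next step.
    step_key, new_key = jax.random.split(state["training_key"])

    # ... use step_key for dropout / sampling / anything random on `batch` ...

    # The important bit: the *new* key goes back into the returned state.
    return {**state, "training_key": new_key}

state = {"training_key": jax.random.PRNGKey(0)}
state = train_step(state, batch=None)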

Good things:

  • Setting up the factorized model was straightforward, especially thanks to the testing machinery: I could just iterate with pytest pretty quickly. The HF integration/state_dict was a bit harder (but it definitely helped to have the roundtrip tests!). I banged my head on the layer linearization and trying to get the shapes aligned etc., but this is a pretty unusual setup, because the model weights radically diverge from the HF weights. In general I've found whacking around with the models a bit easier than mucking with the trainer setup.
  • One nice thing that happened was by using the FactorizedLinear everywhere, I could centralize the HF encoding logic and avoid the flatten/unflatten logic in each transformer class. (I have a more grandiose thought of how we could factor out and maybe separately register the HF machinery but I think I should wait on that until I prove I actually know what I'm doing... 😛 )

There are a lot of my bad decisions here:

  • I blundered a lot because I don't have much experience with JAX yet (!! yeah, I know)
  • I made a copy of Trainer since I knew I was going to need to hack around to understand things and get it working. Now that it's kind of working, I feel like you could maybe pull this off via subclassing, but it feels on the edge of understandability. I don't have any good ideas, but I think maybe we could find some components of Trainer to make a bit more "top-level" and then re-use in different training paradigms as needed. I'll think on this some more: I'm thinking there's maybe a world where Trainer -> SupervisedModelTrainer but some of the ideas/components can be used elsewhere. This might be overkill too!
  • There are dumb things, like the fact that I pulled out StepInfo and hooks: not for any good reason, but because I got a weird error around StepInfo and the JAX cache and I just wanted to get things working and then revisit.
  • I had some confusion around getting the layerwise loss working correctly. It took me a bit to realize the default loss is hidden in the model and then I got tripped up by not setting the computation_axes correctly on my loss. This was all me: if I had bothered to read Trainer carefully I'd likely have understood the flow, but this part feels like something that would be easy to "lift" and share across different training paradigms.
  • JAX/XLA was wickedly slow when I tried to compute loss & gradients for each layer incrementally, but at least for my small tests it did okay when I summed the individual layerwise losses: in theory I think this should give the same gradients/training behavior (see the toy check after this list), but 🤷 . I disabled scan_layers to make it easy for me to invoke the layers individually, but from our discussion that cripples performance: I wonder if I need to be more clever here?
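
On the "same gradients" point, a toy check of the linearity argument (toy loss functions standing in for per-layer losses, nothing Levanter-specific): the gradient of the summed loss equals the sum of the per-loss gradients.

import jax
import jax.numpy as jnp

def layer_losses(params):
    # Stand-ins for per-layer losses that all depend on shared parameters.
    return jnp.array([jnp.sum(params ** 2), jnp.sum(jnp.sin(params))])

params = jnp.arange(3.0)

# Gradient of the summed loss...
g_summed = jax.grad(lambda p: layer_losses(p).sum())(params)
# ...equals the sum of gradients taken loss-by-loss.
g_per_layer = sum(jax.grad(lambda p, i=i: layer_losses(p)[i])(params) for i in range(2))

assert jnp.allclose(g_summed, g_per_layer)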

rjpower force-pushed the factor-llama branch 2 times, most recently from 6c5679c to 3a84009 on May 31, 2024 at 21:37
@dlwh (Member) left a comment

Sorry it took so long! Super busy week!

src/levanter/layerwise_trainer.py (outdated review thread, resolved)

_loss_fn = hax.filter_checkpoint(_per_layer_loss)

loss, _ = hax.fold(
dlwh (Member):

I think this is probably the best way to do this. You can of course use scan if you wanted to keep per-layer losses for logging or weighting or something.
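
As a plain-JAX analogue of the fold-vs-scan choice (toy shapes and a stand-in layer, not the hax.fold/hax.scan API itself): the scanned function's second output is what scan keeps per layer, so you get one loss per layer for logging or weighting, while a fold-style use just carries the running total.

import jax
import jax.numpy as jnp

def layer_step(carry, layer_params):
    hidden, total = carry
    hidden = jnp.tanh(hidden @ layer_params)      # stand-in for the real layer
    loss = jnp.mean(hidden ** 2)                  # stand-in for the per-layer loss
    return (hidden, total + loss), loss           # the second output is stacked by scan

layers = jnp.ones((4, 8, 8))                      # 4 toy stacked layers
init = (jnp.ones((8,)), jnp.zeros(()))

(_, total_loss), per_layer_losses = jax.lax.scan(layer_step, init, layers)
# per_layer_losses has shape (4,): one entry per layer, available for logging/weighting.
# A fold would only keep the final carry (total_loss).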

rjpower (Collaborator, Author):

Yeah, this took a bit of back and forth to figure out, but it seems sensible now.

Initially I used an unrolled loop, which "worked" on a tiny model but effectively stalled in compilation for the larger model sizes.
When I switched to scan, I ran into a few issues, mostly the normal JAX/TPU stuff, which shouldn't affect most regular models:

  • The is_scanned magic for hax.fold was trying to scan over the initial values in addition to the layers (because by default it tries to glom any NamedArray into the scan). Maybe it's worth exempting the first argument from that logic?
  • It took a while to realize I had to enable gradient checkpointing. Everything "works" in a CPU test, but it just OOMs the process. Trying it on a TPU yields XLA error messages about trying to allocate >1TB of memory, asking you to pass magic flags to actually see the large tensors. (And of course, when you provide those flags, it doesn't change anything...) I ended up needing to reduce the size of the model to convince XLA I was close enough to fitting for it to actually bother to do the allocation analysis and cough up the errors... (a minimal checkpointing sketch follows at the end of this comment).

I'm also obviously special-casing to the stacked variant of the transformer, which feels a little gross, but... I suspect it's not worth generalizing the hof.py modules to support this.
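
To illustrate the checkpointing point (a minimal plain-JAX sketch with toy shapes, not the layerwise trainer itself): wrapping the per-layer body in jax.checkpoint inside the scan lets the backward pass rematerialize each layer's activations instead of keeping all of them live, which is what blows up memory without it.

import jax
import jax.numpy as jnp

def layer(hidden, params):
    return jnp.tanh(hidden @ params)

# Recompute each layer's activations during the backward pass instead of storing them.
checkpointed_layer = jax.checkpoint(layer)

def loss_fn(layers, hidden):
    def step(h, p):
        h = checkpointed_layer(h, p)
        return h, jnp.mean(h ** 2)                # per-layer loss
    _, per_layer = jax.lax.scan(step, hidden, layers)
    return per_layer.sum()

layers = jnp.ones((4, 8, 8))
grads = jax.grad(loss_fn)(layers, jnp.ones((8,)))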

src/levanter/layerwise_trainer.py (outdated review thread, resolved)
src/levanter/layerwise_trainer.py (review thread, resolved)
d[prefix + ".down_proj.weight"] = down_proj
d[prefix + ".up_proj.weight"] = up_proj

d.update(
dlwh (Member):

I need to finish the branch, but this should be very simple once it's merged. You'll have to do your low-rank approximation manually, but otherwise all of this will be handled for you.

In case you're curious/want to provide feedback, the branches are:

rjpower (Collaborator, Author):

Oh, that's much nicer indeed. I like the separation of determining the saveable state and then transforming it to HF, and having Stacked et al. handle more of the load.

The linearization and out_first logic was definitely tedious to work with when I was trying to get this working. I ended up creating one of those weird models where every dimension is a prime number to figure out what was going where :). (It didn't help that I couldn't figure out the real shape that SVD was outputting... a small shape sketch is below.)
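
For the record, a toy sketch of the SVD shapes with prime dimensions (hypothetical rank and sizes, and which factor gets called down_proj vs up_proj here is just for illustration): with full_matrices=False, a (7, 5) weight gives U of shape (7, 5), s of shape (5,), and Vh of shape (5, 5), and truncating to rank r yields the two low-rank factors.

import jax
import jax.numpy as jnp

W = jax.random.normal(jax.random.PRNGKey(0), (7, 5))   # toy (out=7, in=5) weight; primes keep shapes unambiguous
U, s, Vh = jnp.linalg.svd(W, full_matrices=False)
print(U.shape, s.shape, Vh.shape)                       # (7, 5) (5,) (5, 5)

r = 3                                                   # hypothetical target rank
down_proj = jnp.diag(s[:r]) @ Vh[:r]                    # (3, 5): maps the input down to rank r
up_proj = U[:, :r]                                      # (7, 3): maps back up to the output dim
approx = up_proj @ down_proj                            # (7, 5) rank-3 approximation of W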

return down_proj, up_proj


class FactorizedLinear(StateDictSerializationMixin, eqx.Module):
dlwh (Member):

So one thing you could do is do all of this as a tree transformation / model surgery, similar to how we do it for LoRA.

Basically you load the model as normal, then you do something like:

def replace_linear(x):
    # Leave anything that isn't a Linear leaf untouched.
    if not isinstance(x, Linear):
        return x

    up_proj, down_proj = low_rank_approximation(x)
    return LowRankLinear(up_proj, down_proj)

# is_leaf stops tree_map from recursing into Linear modules, so each one is
# handed to replace_linear whole.
modified_model = jax.tree_util.tree_map(replace_linear, model, is_leaf=lambda x: isinstance(x, Linear))

I'm pretty sure you could delete most of this file if you did this.

This version only works if use_scan_layers is false. You have to do some fanciness (that we do in LoRA) for scan layers: essentially, vmap the replace_linear function whenever you detect a Stacked module. (Maybe I should make a stacked_aware_tree_map or something.) A plain-array sketch of the vmap idea is below.
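
A plain-array sketch of the vmap idea (toy weights with a leading layer axis; the real version operates on Stacked modules rather than raw arrays): the per-layer factorization is just mapped over the stacked axis.

import jax
import jax.numpy as jnp

def low_rank_factors(w, r=3):
    # Factor one (out, in) weight into up (out, r) and down (r, in).
    u, s, vh = jnp.linalg.svd(w, full_matrices=False)
    return u[:, :r], jnp.diag(s[:r]) @ vh[:r]

# A "stacked" layer carries a leading num_layers axis on its weights...
stacked_w = jax.random.normal(jax.random.PRNGKey(0), (4, 7, 5))   # (layers, out, in)

# ...so the per-layer replacement is just vmapped over that axis.
up, down = jax.vmap(low_rank_factors)(stacked_w)
print(up.shape, down.shape)                                       # (4, 7, 3) (4, 3, 5)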

rjpower (Collaborator, Author):

Oh, that's a great idea. You'd lose a little flexibility in the size of your low-rank layers, but it would be nice for it to work across all models.

I'll try that out once I have something working end-to-end. (I don't mind starting with this and throwing it away, since it's a bit easier to debug this way than also debugging any wonky transform logic I try to write...)

@dlwh (Member) commented Jun 1, 2024

> Random things:
>
> * I noticed that `state.training_key` isn't updated with the `new_key` when we take a step: https://github.com/stanford-crfm/levanter/blob/main/src/levanter/trainer.py#L495 . IIUC, we'll use the same randomness for every batch as a result. (I could be horribly wrong here). That said, it doesn't look like we use dropout or any other runtime randomness, so it probably doesn't change anything either way!

Oops, that got lost in my last attempt to refactor. I should have a test for that.

Dropout is super slow on TPU so I tend to avoid it...
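
For what it's worth, a minimal pytest-style sketch of the property such a test would pin down (raw JAX PRNG keys and made-up test names, not the actual Trainer API): reusing a key yields identical draws every step, while threading a fresh split through gives new randomness.

import jax
import jax.numpy as jnp

def test_key_reuse_vs_threading():
    key = jax.random.PRNGKey(0)
    # If training_key is never advanced, every step draws the same numbers...
    assert jnp.array_equal(jax.random.normal(key, (3,)), jax.random.normal(key, (3,)))
    # ...whereas splitting and carrying the new key forward gives fresh draws.
    step_key, new_key = jax.random.split(key)
    assert not jnp.array_equal(jax.random.normal(step_key, (3,)), jax.random.normal(new_key, (3,)))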

> Good things:
>
> * Setting up the factorized model was straightforward, especially thanks to the testing machinery: I could just iterate with pytest pretty quickly. The HF integration/state_dict was a bit harder (but it definitely helped to have the roundtrip tests!). I banged my head on the layer linearization and trying to get the shapes aligned etc., but this is a pretty unusual setup, because the model weights radically diverge from the HF weights. In general I've found whacking around with the models a bit easier than mucking with the trainer setup.

Noted. The branches in Haliax/Levanter should make most of the "bookkeeping" parts go away, and if you moved to a "tree transformation" based approach, I think you shouldn't have to modify the core Llama code at all. (I think?!?)

> * One nice thing that happened was by using the FactorizedLinear everywhere, I could centralize the HF encoding logic and avoid the flatten/unflatten logic in each transformer class. (I have a more grandiose thought of how we could factor out and maybe separately register the HF machinery but I think I should wait on that until I prove I actually know what I'm doing... 😛 )

I'm all ears (though please do see the twin state dict branches and lmk what you think!)


@dlwh (Member) commented Jun 1, 2024

> There are a lot of my bad decisions here:
>
> * I blundered a lot because I don't have much experience with JAX yet (!! yeah, I know)

We all do! I think JAX needs better docs / examples for the "not a complete beginner but not a JAX whisperer" crowd.

> * I made a copy of Trainer since I knew I was going to need to hack around to understand things and get it working. Now that it's kind of working, I feel like you could _maybe_ pull this off via subclassing, but it feels on the edge of understandability. I don't have any good ideas, but I think maybe we could find some components of Trainer to make a bit more "top-level" and then re-use in different training paradigms as needed. I'll think on this some more: I'm thinking there's maybe a world where Trainer -> SupervisedModelTrainer but some of the ideas/components can be used elsewhere. This might be overkill too!

Yeah... still trying to figure this out. Some bits like the batch loader helpers and some of the mixed precision/sharding logic should be "lego-ified" and split out. Probably other stuff too. I think my goal should be to make copy-paste not feel bad, if that makes sense.

> * There are dumb things, like the fact that I pulled out StepInfo and hooks: not for any good reason, but because I got a weird error around StepInfo and the JAX cache and I just wanted to get things working and then revisit.

hrm. please send me a stack trace if you figure it out

> * I had some confusion around getting the layerwise loss working correctly. It took me a bit to realize the default loss is hidden in the model and then I got tripped up by not setting the computation_axes correctly on my loss. This was all me: if I had bothered to read Trainer carefully I'd likely have understood the flow, but this part feels like something that would be easy to "lift" and share across different training paradigms.

Ah yeah, I don't love that. I'd like to lego-ify that too

> * JAX/XLA was wickedly slow when I tried to compute loss & gradients for each layer incrementally, but at least for my small tests it did okay when I summed the individual layerwise losses: in theory I think this should give the same gradients/training behavior, but 🤷 . I disabled `scan_layers` to make it easy for me to invoke the layers individually, but from our discussion that cripples performance: I wonder if I need to be more clever here?

hrm. I do think the timings in the compiler are pretty carefully tuned for the "standard" case and maybe you were tripping it up? I do wish they'd offered PGO for TPU like they're starting to do for GPU... though maybe that's not the problem here.

@rjpower (Collaborator, Author) left a comment

Thanks for the feedback!

I think things are working, mostly, on the training side now: I can at least run steps. Unfortunately when I try to init the models from HF, things get stuck again for some reason. I'll need to add some more logging to see if I can figure out what's triggering that.


@rjpower (Collaborator, Author) commented Jun 2, 2024

> Yeah... still trying to figure this out. Some bits like the batch loader helpers and some of the mixed precision/sharding logic should be "lego-ified" and split out. Probably other stuff too. I think my goal should be to make copy-paste not feel bad, if that makes sense.

+1, and overall I think things have been fine: when it's been slow getting things working, it's because I hit some indecipherable XLA or JAX error, not because I had to hack around Levanter. If this were just another model, I don't see any reason I wouldn't have been able to reuse the default training setup easily. But I suspect there are some parts here we can make easier to use while preserving the existing workflow.

> hrm. please send me a stack trace if you figure it out

I'll let you know if I can reproduce. I couldn't understand it because all that stuff is outside of the JIT scope, but I thought, "no need for StepInfo until I can take a step..."

> * JAX/XLA was wickedly slow when I tried to compute loss & gradients for each layer incrementally, but at least for my small tests it did okay when I summed the individual layerwise losses: in theory I think this should give the same gradients/training behavior, but 🤷 . I disabled scan_layers to make it easy for me to invoke the layers individually, but from our discussion that cripples performance: I wonder if I need to be more clever here?

> hrm. I do think the timings in the compiler are pretty carefully tuned for the "standard" case and maybe you were tripping it up? I do wish they'd offered PGO for TPU like they're starting to do for GPU... though maybe that's not the problem here.

Yeah, there are certainly some tuning issues with the compiler if you don't take the right approach. I didn't investigate it much in this case though: it might just have produced a giant program and stalled trying to compile it...

rjpower force-pushed the factor-llama branch 4 times, most recently from 6d40fc2 to e5ec7ca on June 3, 2024 at 22:54
@dlwh (Member) commented Jun 5, 2024

> The is_scanned magic for hax.fold was trying to scan over the initial values in addition to the layers (because by default it tries to glom any NamedArray into the scan). Maybe it's worth exempting the first argument from that logic?

Can you open an issue? I agree this is better.

@dlwh closed this on Sep 24, 2024